{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark = SparkSession.builder.appName('naive_bayes').getOrCreate()  #创建SparkSession对象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#irisdf = spark.sql(\"SELECT SepalLength, SepalWidth, PetalLength, PetalWidth, Species FROM iris\")\n",
    "irisdf=spark.read.csv('iris_dataset.csv',inferSchema=True,header=True) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(150, 5)\n"
     ]
    }
   ],
   "source": [
    "print((irisdf.count(),len(irisdf.columns)))  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+------------------+-------------------+------------------+------------------+\n",
      "|summary|      sepal_length|        sepal_width|      petal_length|       petal_width|\n",
      "+-------+------------------+-------------------+------------------+------------------+\n",
      "|  count|               150|                150|               150|               150|\n",
      "|   mean| 5.843333333333335| 3.0540000000000007|3.7586666666666693|1.1986666666666672|\n",
      "| stddev|0.8280661279778637|0.43359431136217375| 1.764420419952262|0.7631607417008414|\n",
      "|    min|               4.3|                2.0|               1.0|               0.1|\n",
      "|    max|               7.9|                4.4|               6.9|               2.5|\n",
      "+-------+------------------+-------------------+------------------+------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "irisdf.describe().select('summary','sepal_length','sepal_width','petal_length','petal_width').show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- sepal_length: double (nullable = true)\n",
      " |-- sepal_width: double (nullable = true)\n",
      " |-- petal_length: double (nullable = true)\n",
      " |-- petal_width: double (nullable = true)\n",
      " |-- species: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "irisdf.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import StringIndexer\n",
    "labelIndexer = StringIndexer(inputCol=\"species\", outputCol=\"label\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler\n",
    "vecAssembler = VectorAssembler(inputCols=[\"sepal_length\", \"sepal_width\", \"petal_length\", \"petal_width\"], outputCol=\"features\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = irisdf.randomSplit([0.7, 0.3], seed = 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.classification import NaiveBayes\n",
    "from pyspark.ml import Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb = NaiveBayes(smoothing=1.0, modelType=\"multinomial\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = Pipeline(stages=[labelIndexer, vecAssembler, nb])   #构建一个stages，其中包含特征转化和贝叶斯估计器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = pipeline.fit(train)   #pipeline由一系列阶段组成，每个阶段要么是估计器，要么是转换器。当调用fit，按顺序执行阶段"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- sepal_length: double (nullable = true)\n",
      " |-- sepal_width: double (nullable = true)\n",
      " |-- petal_length: double (nullable = true)\n",
      " |-- petal_width: double (nullable = true)\n",
      " |-- species: string (nullable = true)\n",
      " |-- label: double (nullable = false)\n",
      " |-- features: vector (nullable = true)\n",
      " |-- rawPrediction: vector (nullable = true)\n",
      " |-- probability: vector (nullable = true)\n",
      " |-- prediction: double (nullable = false)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+----------+-------------------------------------------------------------+\n",
      "|label|prediction|probability                                                  |\n",
      "+-----+----------+-------------------------------------------------------------+\n",
      "|0.0  |0.0       |[0.5112170900860498,0.42378835062808623,0.06499455928586405] |\n",
      "|0.0  |0.0       |[0.5025414372542717,0.41477046961196545,0.08268809313376278] |\n",
      "|2.0  |2.0       |[0.15134048617914386,0.0781776035394255,0.7704819102814305]  |\n",
      "|1.0  |1.0       |[0.4861766353137803,0.5021723246946296,0.011651039991590116] |\n",
      "|1.0  |1.0       |[0.4728547962801507,0.5179831760092096,0.009162027710639722] |\n",
      "|1.0  |1.0       |[0.4633454865070346,0.530831347337991,0.005823166154974355]  |\n",
      "|2.0  |2.0       |[0.24774327022067563,0.1381935673888351,0.6140631623904893]  |\n",
      "|0.0  |0.0       |[0.5200371667676991,0.4304739552641891,0.04948887796811193]  |\n",
      "|0.0  |0.0       |[0.5117917806583538,0.4527847963979916,0.03542342294365454]  |\n",
      "|1.0  |1.0       |[0.47322193519079037,0.5176674061680799,0.009110658641129801]|\n",
      "+-----+----------+-------------------------------------------------------------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql.functions import rand, randn\n",
    "predictions.select(\"label\", \"prediction\", \"probability\").orderBy(rand()).show(10,truncate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator    #为了评估我们的模型，我们将在多类分类中使用评估器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluator = MulticlassClassificationEvaluator(labelCol=\"label\", predictionCol=\"prediction\",metricName='accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy = evaluator.evaluate(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测的准确性：0.936170212766\n"
     ]
    }
   ],
   "source": [
    "print (u'预测的准确性：{}'.format(accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'metricName: metric name in evaluation (f1|weightedPrecision|weightedRecall|accuracy) (default: f1, current: accuracy)'"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluator.explainParam(\"metricName\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
